#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from __future__ import print_function
import os
import sys
import glob
import argparse
import logging
import time
import platform
from multiprocessing import cpu_count
from threading import Thread
from subprocess import run
from queue import Queue, Empty

# Seconds a queue read blocks before giving up (used for both the test
# queues and the output queue).
TIMEOUT = 0.2

# Command prefix placed in front of every executed script: on Windows the
# scripts are launched through "python3", elsewhere they are run directly.
CMD = ["python3"] if platform.system() == "Windows" else []


def comment(text):
    """Turn every non-blank line of *text* into a TAP comment line.

    Lines starting with "ok" / "not ok" are prefixed with
    "# (original state) ", lines already starting with "# " are kept as they
    are, and any other line is prefixed with "# ".  Blank or whitespace-only
    lines are dropped.  The result has no trailing newline.
    """
    def _as_comment(line):
        if line.startswith(("ok", "not ok")):
            return "# (original state) " + line
        if line.startswith("# "):
            return line
        return "# " + line

    return "\n".join(
        _as_comment(line) for line in text.split("\n") if line.strip()
    )


def run_test(testqueue, outqueue, threadname):
    """Worker loop: run queued test scripts and publish their output.

    Repeatedly takes a test path from *testqueue*, executes it, and puts a
    tuple of text chunks on *outqueue*.  The loop ends when no test can be
    fetched within TIMEOUT seconds.

    A test counts as failed when it cannot be launched or when it exits with
    a non-zero code.  For a failed test the captured stdout/stderr are turned
    into TAP comments and a synthetic "not ok" line carrying the reason is
    appended; otherwise stdout and stderr are passed through unchanged.

    :param testqueue: queue of test file paths to execute
    :param outqueue: queue receiving one tuple of strings per executed test
    :param threadname: label used only in the final log message
    """
    start = time.time()
    while True:
        try:
            test = testqueue.get(block=True, timeout=TIMEOUT)
        except Empty:
            # Nothing left to fetch within the timeout: this worker is done
            break

        log.info("Running test %s", test)

        failed = False
        reason = ""

        try:
            p = run(CMD + [os.path.abspath(test)], capture_output=True,
                    env=os.environ, text=True)
        except Exception as e:
            # The test could not even be started
            log.exception(e)
            failed = True
            reason = str(e)
            out = ""
            err = ""
        else:
            if p.returncode:
                log.info(
                    "Test %s exitcode was %s. Tests shouldn't use exit code. "
                    "They are expected to exit cleanly and output 'ok' or 'not ok'",
                    test, p.returncode
                )
                failed = True
                reason = "Exit code was {0}".format(p.returncode)
            out = p.stdout
            err = p.stderr

        test_head = "# {0}\n".format(os.path.basename(test))

        if failed:
            commented_out = comment(out)
            commented_err = comment(err)
            if commented_out and commented_err:
                # comment() returns text without a trailing newline; without
                # this separator the last stdout line and the first stderr
                # line would end up glued together on a single line.
                commented_out += "\n"
            output = (test_head, commented_out, commented_err,
                      "\nnot ok - {0}\n".format(reason))
        else:
            output = (test_head, out, err)

        log.debug("Collected output %s", output)
        outqueue.put(output)

        testqueue.task_done()

    log.warning("Finished %s thread after %s seconds",
                threadname, round(time.time() - start, 3))


class TestRunner(object):
    """Find executable ``*.t`` tests, run them on worker threads and write
    their output to a TAP file.

    The class reads the module-level ``cmd_args`` (parsed command line) and
    ``log`` (logger), so both must exist before it is used.

    The TAP file is opened by the constructor.  Call :meth:`close` when done,
    or use the instance as a context manager, to release the file handle.
    """

    def __init__(self):
        self.threads = []
        self.tap = open(cmd_args.tapfile, 'w')

        self._parallelq = Queue()
        self._serialq = Queue()
        self._outputq = Queue()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        # Never swallow an exception raised inside the ``with`` body
        return False

    def close(self):
        """Close the TAP file.  Calling it more than once is harmless."""
        self.tap.close()

    def _find_tests(self):
        """Queue every executable ``*.t`` file of the current directory,
        sorted into the parallel or the serial queue."""
        for test in glob.glob("*.t"):
            if not os.access(test, os.X_OK):
                # Executables only
                log.debug("Ignored test %s as it is not executable", test)
                continue

            if self._is_parallelizable(test):
                log.debug("Treating as parallel: %s", test)
                self._parallelq.put(test)
            else:
                log.debug("Treating as serial: %s", test)
                self._serialq.put(test)

        log.info("Parallel tests: %s", self._parallelq.qsize())
        log.info("Serial tests: %s", self._serialq.qsize())

    def _prepare_threads(self):
        """Create (but do not start) one serial worker and one parallel
        worker per CPU."""
        # Serial thread
        self.threads.append(
            Thread(target=run_test, args=(self._serialq, self._outputq, "Serial"))
        )
        # Parallel threads
        self.threads.extend([
            Thread(target=run_test, args=(self._parallelq, self._outputq, "Parallel"))
            for _ in range(cpu_count())
        ])
        log.info("Spawned %s threads to run tests", len(self.threads))

    def _start_threads(self):
        """Start all prepared worker threads as daemon threads."""
        for thread in self.threads:
            log.debug("Starting thread %s", thread)
            # Threads die when main thread dies
            thread.daemon = True
            thread.start()

    def _print_timestamp_to_tap(self):
        """Write a comment line with the current epoch and local time to the
        TAP file."""
        now = time.time()
        timestamp = "# {0} ==> {1}\n".format(
            now,
            time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(now)),
        )

        log.debug("Adding timestamp %s to TAP file", timestamp)
        self.tap.write(timestamp)

    def _is_parallelizable(self, test):
        """Return True when *test* may run alongside other tests.

        Always False when ``--serial`` was requested.  Otherwise the first
        100 bytes of the file are inspected: the test is parallelizable when
        its first line mentions ``/usr/bin/env python3`` or its second line
        mentions ``bash_tap``.
        """
        if cmd_args.serial:
            return False

        # This is a pretty weird way to do it, and not reliable.
        # We are dealing with some binary tests though.
        with open(test, 'rb') as fh:
            header = fh.read(100).split(b"\n")

        return len(header) >= 2 and (
            b"/usr/bin/env python3" in header[0] or b"bash_tap" in header[1]
        )

    def _get_remaining_tests(self):
        """Return the number of tests still waiting in both queues."""
        return self._parallelq.qsize() + self._serialq.qsize()

    def is_running(self):
        """Return True while at least one worker thread is alive."""
        return any(thread.is_alive() for thread in self.threads)

    def start(self):
        """Run all tests and stream their output into the TAP file.

        Blocks until every worker has finished and the output queue is
        drained.  With ``--verbose`` the output is echoed to stdout as well.

        :raises RuntimeError: if tests are still queued after all workers
            have stopped
        """
        self._find_tests()
        self._prepare_threads()

        self._print_timestamp_to_tap()

        finished = 0
        total = self._get_remaining_tests()

        self._start_threads()

        while self.is_running() or not self._outputq.empty():
            try:
                outputs = self._outputq.get(block=True, timeout=TIMEOUT)
            except Empty:
                # No result yet; re-check whether workers are still alive
                continue

            log.debug("Outputting to TAP: %s", outputs)

            for output in outputs:
                self.tap.write(output)

                if cmd_args.verbose:
                    sys.stdout.write(output)

            self._outputq.task_done()
            finished += 1

            log.warning("Finished %s out of %s tests", finished, total)

        self._print_timestamp_to_tap()

        if not self._parallelq.empty() or not self._serialq.empty():
            raise RuntimeError(
                "Something went wrong, not all tests were ran. {0} "
                "remaining.".format(self._get_remaining_tests()))

    def show_report(self):
        """Run ``problems --summary`` on the TAP file and return its exit
        code."""
        # Make sure everything written so far is on disk / on screen before
        # the external report tool reads the TAP file and prints its summary.
        self.tap.flush()
        sys.stdout.flush()
        sys.stderr.flush()

        log.debug("Calling 'problems --summary' for report")
        return run(CMD + [os.path.abspath("problems"), "--summary", cmd_args.tapfile]).returncode


def parse_args():
    """Parse the runner's command-line options from ``sys.argv``.

    Returns the ``argparse.Namespace`` carrying the ``verbose``, ``logging``,
    ``serial`` and ``tapfile`` attributes.
    """
    # (flags, add_argument keyword settings) for each supported option
    option_specs = (
        (('--verbose', '-v'),
         dict(action="store_true",
              help="Also send TAP output to stdout")),
        (('--logging', '-l'),
         dict(action="count", default=0,
              help="Logging level. -lll is the highest level")),
        (('--serial',),
         dict(action="store_true",
              help="Do not run tests in parallel")),
        (('--tapfile',),
         dict(default="all.log",
              help="File to use for TAP output")),
    )

    cli = argparse.ArgumentParser(description="Run tests")
    for flags, settings in option_specs:
        cli.add_argument(*flags, **settings)

    return cli.parse_args()


def main():
    """Run all tests and produce the summary report.

    Returns the exit code of the report step so the caller can use it as the
    process exit status.
    """
    runner = TestRunner()
    try:
        runner.start()

        # Propagate the return code
        return runner.show_report()
    finally:
        # The runner opens its TAP file in the constructor and does not
        # close it on its own; release the handle here, including when
        # start() or show_report() raises.
        runner.tap.close()


if __name__ == "__main__":
    cmd_args = parse_args()

    # Translate the number of -l flags into a logging level:
    # none -> ERROR, -l -> WARN, -ll -> INFO, -lll (or more) -> DEBUG
    if cmd_args.logging >= 3:
        level = logging.DEBUG
    else:
        level = {
            1: logging.WARN,
            2: logging.INFO,
        }.get(cmd_args.logging, logging.ERROR)

    logging.basicConfig(
        format="# !!! %(asctime)s - %(levelname)s - %(message)s",
        level=level,
    )
    log = logging.getLogger(__name__)

    log.debug("Parsed commandline arguments: %s", cmd_args)

    # Exit with main()'s return value, or with 1 after logging any
    # unexpected exception.
    try:
        exit_code = main()
    except Exception as exc:
        log.exception(exc)
        exit_code = 1

    sys.exit(exit_code)

# vim: ai sts=4 et sw=4
